"""Tests for pyi_checker definitions.

These are sanity checks that the `from_node` constructors produce the expected definitions.
"""

import textwrap

from pytype.tools.pyi_checker import definitions
from pytype.tools.pyi_checker import test_utils as utils
from typed_ast import ast3
import unittest


class DefinitionFromNodeTest(unittest.TestCase):
  """Checks that definitions built from parsed source match hand-built ones."""

  def test_function_basic(self):
    expected = utils.make_func(name="foo")
    actual = utils.func_from_source("def foo(): pass")
    self.assertEqual(expected, actual)

  def test_function_all_args(self):
    # Exercises every parameter kind: positional (with and without a default),
    # *vararg, keyword-only (with and without a default) and **kwarg.
    func_str = "def foo(arg1, arg2=2, *var, key1, key2=2, **keys): pass"
    expected = utils.make_func(
        name="foo",
        params=[utils.make_arg("arg1", col_offset=8),
                utils.make_arg("arg2", col_offset=14, has_default=True)],
        vararg=utils.make_arg("var", col_offset=23),
        kwonlyargs=[utils.make_arg("key1", col_offset=28),
                    utils.make_arg("key2", col_offset=34, has_default=True)],
        kwarg=utils.make_arg("keys", col_offset=44))
    actual = utils.func_from_source(func_str)
    self.assertEqual(expected, actual)

  def test_function_decorators(self):
    func_str = textwrap.dedent("""\
        @some_decorator
        @another_decorator
        def test():
          pass
        """)
    expected = utils.make_func(
        name="test",
        # The `def` is on line 3, but the node's lineno is the line of the
        # first decorator.
        lineno=1,
        decorators=["some_decorator", "another_decorator"])
    actual = utils.func_from_source(func_str)
    self.assertEqual(expected, actual)

  def test_function_async(self):
    func_str = "async def test(): pass"
    expected = utils.make_func(
        name="test",
        col_offset=6,  # "async" doesn't count as part of definition.
        is_async=True)
    actual = utils.func_from_source(func_str)
    self.assertEqual(expected, actual)

  def test_function_sigs_differ(self):
    # Since definitions are dataclasses, they're compared as if they were
    # tuples of their fields. So these two Function instances are considered
    # different, because their keyword-only arguments appear in a different
    # order. The richer comparison performed by the pyi checker would consider
    # them the same, since keyword-only argument order doesn't really matter.
    func1 = utils.func_from_source("def foo(a, *c, d, e): pass")
    func2 = utils.func_from_source("def foo(a, *c, e, d): pass")
    self.assertNotEqual(func1, func2)

  def test_variable(self):
    expected = definitions.Variable(name="x", source="", lineno=1, col_offset=0)
    actual = utils.var_from_source("x")
    self.assertEqual(expected, actual)

  def test_class_members(self):
    # The DefinitionFinder visitor builds a Class by passing Class.from_node
    # the class def node plus dicts (keyed by name) of its fields, methods and
    # nested classes, in that order. This test builds each piece separately to
    # mimic that.
    class_stmt = textwrap.dedent("""\
      class A:
        class_field = 3
        def __init__(self, arg):
          self.instance_field = arg
        def a_method(self, arg):
          return self.instance_field + arg
        class _simple_nested_cls:
          pass
      """)
    expected_methods = {
        "__init__": utils.make_func(
            name="__init__",
            lineno=3,
            col_offset=2,
            params=[utils.make_arg("self", lineno=3, col_offset=15),
                    utils.make_arg("arg", lineno=3, col_offset=21)]
        ),
        "a_method": utils.make_func(
            name="a_method",
            lineno=5,
            col_offset=2,
            params=[utils.make_arg("self", lineno=5, col_offset=15),
                    utils.make_arg("arg", lineno=5, col_offset=21)]
        ),
    }
    expected_fields = {
        "class_field": definitions.Variable(
            "class_field", source="", lineno=2, col_offset=2),
        "instance_field": definitions.Variable(
            "instance_field", source="", lineno=4, col_offset=4)
    }
    expected_nests = {
        "_simple_nested_cls": definitions.Class(
            name="_simple_nested_cls", source="",
            lineno=7, col_offset=2,
            bases=[], keyword_bases=[], decorators=[],
            fields={}, methods={}, nested_classes={})
    }
    expected_class = definitions.Class(
        name="A",
        source="",
        lineno=1,
        col_offset=0,
        bases=[],
        keyword_bases=[],
        decorators=[],
        fields=expected_fields,
        methods=expected_methods,
        nested_classes=expected_nests)
    # We have to pull apart the parsed class by hand.
    node = utils.parse_stmt(class_stmt)
    classfield, init, method, nested = node.body
    # init.body[0] is the ast3.Assign for `self.instance_field = arg`; its
    # target is an ast3.Attribute. For instance fields the Variable must
    # describe the attribute (`instance_field`) rather than the value it is
    # accessed on (`self`). `attr` is a plain str, not an AST node, so
    # from_node can't be used; the Variable is built by hand using the
    # Attribute's position instead.
    instance_field = init.body[0].targets[0]
    actual_fields = {
        "class_field": definitions.Variable.from_node(classfield.targets[0]),
        "instance_field": definitions.Variable(
            name=instance_field.attr,
            source="",
            lineno=instance_field.lineno,
            col_offset=instance_field.col_offset),
    }
    actual_methods = {
        "__init__": definitions.Function.from_node(init),
        "a_method": definitions.Function.from_node(method),
    }
    actual_nests = {
        "_simple_nested_cls": definitions.Class.from_node(nested, {}, {}, {}),
    }
    actual_class = definitions.Class.from_node(
        node, actual_fields, actual_methods, actual_nests)
    self.assertEqual(expected_class, actual_class)


if __name__ == "__main__":
  # When executed directly as a script, load and run the tests in this module.
  unittest.main(module="__main__")

